{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HZiF5lbumA7j"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "KsOkK8O69PyT"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eNj0_BTFk479"
      },
      "source": [
        "# TF Lattice Premade Models"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "T3qE8F5toE28"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/lattice/tutorials/premade_models\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/lattice/blob/master/docs/tutorials/premade_models.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/lattice/blob/master/docs/tutorials/premade_models.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/lattice/docs/tutorials/premade_models.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HEuRMAUOlFZa"
      },
      "source": [
        "## Overview\n",
        "\n",
        "Premade Models are quick and easy ways to build TFL `tf.keras.model` instances for typical use cases. This guide outlines the steps needed to construct a TFL Premade Model and train/test it. "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f2--Yq21lhRe"
      },
      "source": [
        "## Setup\n",
        "\n",
        "Installing TF Lattice package:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XizqBCyXky4y"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "!pip install tensorflow-lattice pydot"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2oKJPy5tloOB"
      },
      "source": [
        "Importing required packages:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9wZWJJggk4al"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "import copy\n",
        "import logging\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import sys\n",
        "import tensorflow_lattice as tfl\n",
        "logging.disable(sys.maxsize)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kpJJSS7YmLbG"
      },
      "source": [
        "Downloading the UCI Statlog (Heart) dataset:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AYTcybljmQJm"
      },
      "outputs": [],
      "source": [
        "csv_file = tf.keras.utils.get_file(\n",
        "    'heart.csv', 'http://storage.googleapis.com/download.tensorflow.org/data/heart.csv')\n",
        "df = pd.read_csv(csv_file)\n",
        "train_size = int(len(df) * 0.8)\n",
        "train_dataframe = df[:train_size]\n",
        "test_dataframe = df[train_size:]\n",
        "df.head()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ODe0oavWmtAi"
      },
      "source": [
        "Extract and convert features and labels to tensors:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3ae-Mx-PnGGL"
      },
      "outputs": [],
      "source": [
        "# Features:\n",
        "# - age\n",
        "# - sex\n",
        "# - cp        chest pain type (4 values)\n",
        "# - trestbps  resting blood pressure\n",
        "# - chol      serum cholestoral in mg/dl\n",
        "# - fbs       fasting blood sugar > 120 mg/dl\n",
        "# - restecg   resting electrocardiographic results (values 0,1,2)\n",
        "# - thalach   maximum heart rate achieved\n",
        "# - exang     exercise induced angina\n",
        "# - oldpeak   ST depression induced by exercise relative to rest\n",
        "# - slope     the slope of the peak exercise ST segment\n",
        "# - ca        number of major vessels (0-3) colored by flourosopy\n",
        "# - thal      3 = normal; 6 = fixed defect; 7 = reversable defect\n",
        "#\n",
        "# This ordering of feature names will be the exact same order that we construct\n",
        "# our model to expect.\n",
        "feature_names = [\n",
        "    'age', 'sex', 'cp', 'chol', 'fbs', 'trestbps', 'thalach', 'restecg',\n",
        "    'exang', 'oldpeak', 'slope', 'ca', 'thal'\n",
        "]\n",
        "feature_name_indices = {name: index for index, name in enumerate(feature_names)}\n",
        "# This is the vocab list and mapping we will use for the 'thal' categorical\n",
        "# feature.\n",
        "thal_vocab_list = ['normal', 'fixed', 'reversible']\n",
        "thal_map = {category: i for i, category in enumerate(thal_vocab_list)}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x5p3OgC-m4TW"
      },
      "outputs": [],
      "source": [
        "# Custom function for converting thal categories to buckets\n",
        "def convert_thal_features(thal_features):\n",
        "  # Note that two examples in the test set are already converted.\n",
        "  return np.array([\n",
        "      thal_map[feature] if feature in thal_vocab_list else feature\n",
        "      for feature in thal_features\n",
        "  ])\n",
        "\n",
        "\n",
        "# Custom function for extracting each feature.\n",
        "def extract_features(dataframe,\n",
        "                     label_name='target',\n",
        "                     feature_names=feature_names):\n",
        "  features = []\n",
        "  for feature_name in feature_names:\n",
        "    if feature_name == 'thal':\n",
        "      features.append(\n",
        "          convert_thal_features(dataframe[feature_name].values).astype(float))\n",
        "    else:\n",
        "      features.append(dataframe[feature_name].values.astype(float))\n",
        "  labels = dataframe[label_name].values.astype(float)\n",
        "  return features, labels"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7DgoAgkIm8tr"
      },
      "outputs": [],
      "source": [
        "train_xs, train_ys = extract_features(train_dataframe)\n",
        "test_xs, test_ys = extract_features(test_dataframe)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qcguGFRcFgCQ"
      },
      "outputs": [],
      "source": [
        "# Let's define our label minimum and maximum.\n",
        "min_label, max_label = float(np.min(train_ys)), float(np.max(train_ys))\n",
        "# Our lattice models may have predictions above 1.0 due to numerical errors.\n",
        "# We can subtract this small epsilon value from our output_max to make sure we\n",
        "# do not predict values outside of our label bound.\n",
        "numerical_error_epsilon = 1e-5"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oyOrtol7mW9r"
      },
      "source": [
        "Setting the default values used for training in this guide:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ns8pH2AnmgAC"
      },
      "outputs": [],
      "source": [
        "LEARNING_RATE = 0.01\n",
        "BATCH_SIZE = 128\n",
        "NUM_EPOCHS = 500\n",
        "PREFITTING_NUM_EPOCHS = 10"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ix2elMrGmiWX"
      },
      "source": [
        "## Feature Configs\n",
        "\n",
        "Feature calibration and per-feature configurations are set using [tfl.configs.FeatureConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/FeatureConfig). Feature configurations include monotonicity constraints, per-feature regularization (see [tfl.configs.RegularizerConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/RegularizerConfig)), and lattice sizes for lattice models.\n",
        "\n",
        "Note that we must fully specify the feature config for any feature that we want our model to recognize. Otherwise the model will have no way of knowing that such a feature exists."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WLSfZ5G7-YT_"
      },
      "source": [
        "### Compute Quantiles\n",
        "\n",
        "Although the default setting for `pwl_calibration_input_keypoints` in `tfl.configs.FeatureConfig` is 'quantiles', for premade models we have to manually define the input keypoints. To do so, we first define our own helper function for computing quantiles. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-LqqEp3k-06d"
      },
      "outputs": [],
      "source": [
        "def compute_quantiles(features,\n",
        "                      num_keypoints=10,\n",
        "                      clip_min=None,\n",
        "                      clip_max=None,\n",
        "                      missing_value=None):\n",
        "  # Clip min and max if desired.\n",
        "  if clip_min is not None:\n",
        "    features = np.maximum(features, clip_min)\n",
        "    features = np.append(features, clip_min)\n",
        "  if clip_max is not None:\n",
        "    features = np.minimum(features, clip_max)\n",
        "    features = np.append(features, clip_max)\n",
        "  # Make features unique.\n",
        "  unique_features = np.unique(features)\n",
        "  # Remove missing values if specified.\n",
        "  if missing_value is not None:\n",
        "    unique_features = np.delete(unique_features,\n",
        "                                np.where(unique_features == missing_value))\n",
        "  # Compute and return quantiles over unique non-missing feature values.\n",
        "  return np.quantile(\n",
        "      unique_features,\n",
        "      np.linspace(0., 1., num=num_keypoints),\n",
        "      interpolation='nearest').astype(float)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ePWXuDH7-1i1"
      },
      "source": [
        "### Defining Our Feature Configs\n",
        "\n",
        "Now that we can compute our quantiles, we define a feature config for each feature that we want our model to take as input."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8y27RmHIrSBn"
      },
      "outputs": [],
      "source": [
        "# Feature configs are used to specify how each feature is calibrated and used.\n",
        "feature_configs = [\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='age',\n",
        "        lattice_size=3,\n",
        "        monotonicity='increasing',\n",
        "        # We must set the keypoints manually.\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        pwl_calibration_input_keypoints=compute_quantiles(\n",
        "            train_xs[feature_name_indices['age']],\n",
        "            num_keypoints=5,\n",
        "            clip_max=100),\n",
        "        # Per feature regularization.\n",
        "        regularizer_configs=[\n",
        "            tfl.configs.RegularizerConfig(name='calib_wrinkle', l2=0.1),\n",
        "        ],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='sex',\n",
        "        num_buckets=2,\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='cp',\n",
        "        monotonicity='increasing',\n",
        "        # Keypoints that are uniformly spaced.\n",
        "        pwl_calibration_num_keypoints=4,\n",
        "        pwl_calibration_input_keypoints=np.linspace(\n",
        "            np.min(train_xs[feature_name_indices['cp']]),\n",
        "            np.max(train_xs[feature_name_indices['cp']]),\n",
        "            num=4),\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='chol',\n",
        "        monotonicity='increasing',\n",
        "        # Explicit input keypoints initialization.\n",
        "        pwl_calibration_input_keypoints=[126.0, 210.0, 247.0, 286.0, 564.0],\n",
        "        # Calibration can be forced to span the full output range by clamping.\n",
        "        pwl_calibration_clamp_min=True,\n",
        "        pwl_calibration_clamp_max=True,\n",
        "        # Per feature regularization.\n",
        "        regularizer_configs=[\n",
        "            tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-4),\n",
        "        ],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='fbs',\n",
        "        # Partial monotonicity: output(0) <= output(1)\n",
        "        monotonicity=[(0, 1)],\n",
        "        num_buckets=2,\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='trestbps',\n",
        "        monotonicity='decreasing',\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        pwl_calibration_input_keypoints=compute_quantiles(\n",
        "            train_xs[feature_name_indices['trestbps']], num_keypoints=5),\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='thalach',\n",
        "        monotonicity='decreasing',\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        pwl_calibration_input_keypoints=compute_quantiles(\n",
        "            train_xs[feature_name_indices['thalach']], num_keypoints=5),\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='restecg',\n",
        "        # Partial monotonicity: output(0) <= output(1), output(0) <= output(2)\n",
        "        monotonicity=[(0, 1), (0, 2)],\n",
        "        num_buckets=3,\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='exang',\n",
        "        # Partial monotonicity: output(0) <= output(1)\n",
        "        monotonicity=[(0, 1)],\n",
        "        num_buckets=2,\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='oldpeak',\n",
        "        monotonicity='increasing',\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        pwl_calibration_input_keypoints=compute_quantiles(\n",
        "            train_xs[feature_name_indices['oldpeak']], num_keypoints=5),\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='slope',\n",
        "        # Partial monotonicity: output(0) <= output(1), output(1) <= output(2)\n",
        "        monotonicity=[(0, 1), (1, 2)],\n",
        "        num_buckets=3,\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='ca',\n",
        "        monotonicity='increasing',\n",
        "        pwl_calibration_num_keypoints=4,\n",
        "        pwl_calibration_input_keypoints=compute_quantiles(\n",
        "            train_xs[feature_name_indices['ca']], num_keypoints=4),\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='thal',\n",
        "        # Partial monotonicity:\n",
        "        # output(normal) <= output(fixed)\n",
        "        # output(normal) <= output(reversible)\n",
        "        monotonicity=[('normal', 'fixed'), ('normal', 'reversible')],\n",
        "        num_buckets=3,\n",
        "        # We must specify the vocabulary list in order to later set the\n",
        "        # monotonicities since we used names and not indices.\n",
        "        vocabulary_list=thal_vocab_list,\n",
        "    ),\n",
        "]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-XuAnP_-vyK6"
      },
      "source": [
        "Next we need to make sure to properly set the monotonicities for features where we used a custom vocabulary (such as 'thal' above)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZIn2-EfGv--m"
      },
      "outputs": [],
      "source": [
        "tfl.premade_lib.set_categorical_monotonicities(feature_configs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Mx50YgWMcxC4"
      },
      "source": [
        "## Calibrated Linear Model\n",
        "\n",
        "To construct a TFL premade model, first construct a model configuration from [tfl.configs](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs). A calibrated linear model is constructed using the [tfl.configs.CalibratedLinearConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/CalibratedLinearConfig). It applies piecewise-linear and categorical calibration on the input features, followed by a linear combination and an optional output piecewise-linear calibration. When using output calibration or when output bounds are specified, the linear layer will apply weighted averaging on calibrated inputs.\n",
        "\n",
        "This example creates a calibrated linear model on the first 5 features."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UvMDJKqTc1vC"
      },
      "outputs": [],
      "source": [
        "# Model config defines the model structure for the premade model.\n",
        "linear_model_config = tfl.configs.CalibratedLinearConfig(\n",
        "    feature_configs=feature_configs[:5],\n",
        "    use_bias=True,\n",
        "    # We must set the output min and max to that of the label.\n",
        "    output_min=min_label,\n",
        "    output_max=max_label,\n",
        "    output_calibration=True,\n",
        "    output_calibration_num_keypoints=10,\n",
        "    output_initialization=np.linspace(min_label, max_label, num=10),\n",
        "    regularizer_configs=[\n",
        "        # Regularizer for the output calibrator.\n",
        "        tfl.configs.RegularizerConfig(name='output_calib_hessian', l2=1e-4),\n",
        "    ])\n",
        "# A CalibratedLinear premade model constructed from the given model config.\n",
        "linear_model = tfl.premade.CalibratedLinear(linear_model_config)\n",
        "# Let's plot our model.\n",
        "tf.keras.utils.plot_model(linear_model, show_layer_names=False, rankdir='LR')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3MC3-AyX00-A"
      },
      "source": [
        "Now, as with any other [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model), we compile and fit the model to our data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hPlEK-yG1B-U"
      },
      "outputs": [],
      "source": [
        "linear_model.compile(\n",
        "    loss=tf.keras.losses.BinaryCrossentropy(),\n",
        "    metrics=[tf.keras.metrics.AUC()],\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))\n",
        "linear_model.fit(\n",
        "    train_xs[:5],\n",
        "    train_ys,\n",
        "    epochs=NUM_EPOCHS,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    verbose=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OG2ua0MGAkoi"
      },
      "source": [
        "After training our model, we can evaluate it on our test set."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HybGTvXxAoxV"
      },
      "outputs": [],
      "source": [
        "print('Test Set Evaluation...')\n",
        "print(linear_model.evaluate(test_xs[:5], test_ys))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jAAJK-wlc15S"
      },
      "source": [
        "## Calibrated Lattice Model\n",
        "\n",
        "A calibrated lattice model is constructed using [tfl.configs.CalibratedLatticeConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/CalibratedLatticeConfig). A calibrated lattice model applies piecewise-linear and categorical calibration on the input features, followed by a lattice model and an optional output piecewise-linear calibration.\n",
        "\n",
        "This example creates a calibrated lattice model on the first 5 features."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "u7gNcrMtc4Lp"
      },
      "outputs": [],
      "source": [
        "# This is a calibrated lattice model: inputs are calibrated, then combined\n",
        "# non-linearly using a lattice layer.\n",
        "lattice_model_config = tfl.configs.CalibratedLatticeConfig(\n",
        "    feature_configs=feature_configs[:5],\n",
        "    output_min=min_label,\n",
        "    output_max=max_label - numerical_error_epsilon,\n",
        "    output_initialization=[min_label, max_label],\n",
        "    regularizer_configs=[\n",
        "        # Torsion regularizer applied to the lattice to make it more linear.\n",
        "        tfl.configs.RegularizerConfig(name='torsion', l2=1e-2),\n",
        "        # Globally defined calibration regularizer is applied to all features.\n",
        "        tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-2),\n",
        "    ])\n",
        "# A CalibratedLattice premade model constructed from the given model config.\n",
        "lattice_model = tfl.premade.CalibratedLattice(lattice_model_config)\n",
        "# Let's plot our model.\n",
        "tf.keras.utils.plot_model(lattice_model, show_layer_names=False, rankdir='LR')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nmc3TUIIGGoH"
      },
      "source": [
        "As before, we compile, fit, and evaluate our model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vIjOQGD2Gp_Z"
      },
      "outputs": [],
      "source": [
        "lattice_model.compile(\n",
        "    loss=tf.keras.losses.BinaryCrossentropy(),\n",
        "    metrics=[tf.keras.metrics.AUC()],\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))\n",
        "lattice_model.fit(\n",
        "    train_xs[:5],\n",
        "    train_ys,\n",
        "    epochs=NUM_EPOCHS,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    verbose=False)\n",
        "print('Test Set Evaluation...')\n",
        "print(lattice_model.evaluate(test_xs[:5], test_ys))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bx74CD4Cc4T3"
      },
      "source": [
        "## Calibrated Lattice Ensemble Model\n",
        "\n",
        "When the number of features is large, you can use an ensemble model, which creates multiple smaller lattices for subsets of the features and averages their output instead of creating just a single huge lattice. Ensemble lattice models are constructed using [tfl.configs.CalibratedLatticeEnsembleConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/CalibratedLatticeEnsembleConfig). A calibrated lattice ensemble model applies piecewise-linear and categorical calibration on the input feature, followed by an ensemble of lattice models and an optional output piecewise-linear calibration."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mbg4lsKqnEkV"
      },
      "source": [
        "### Explicit Lattice Ensemble Initialization\n",
        "\n",
        "If you already know which subsets of features you want to feed into your lattices, then you can explicitly set the lattices using feature names. This example creates a calibrated lattice ensemble model with 5 lattices and 3 features per lattice."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yu8Twg8mdJ18"
      },
      "outputs": [],
      "source": [
        "# This is a calibrated lattice ensemble model: inputs are calibrated, then\n",
        "# combined non-linearly and averaged using multiple lattice layers.\n",
        "explicit_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n",
        "    feature_configs=feature_configs,\n",
        "    lattices=[['trestbps', 'chol', 'ca'], ['fbs', 'restecg', 'thal'],\n",
        "              ['fbs', 'cp', 'oldpeak'], ['exang', 'slope', 'thalach'],\n",
        "              ['restecg', 'age', 'sex']],\n",
        "    num_lattices=5,\n",
        "    lattice_rank=3,\n",
        "    output_min=min_label,\n",
        "    output_max=max_label - numerical_error_epsilon,\n",
        "    output_initialization=[min_label, max_label])\n",
        "# A CalibratedLatticeEnsemble premade model constructed from the given\n",
        "# model config.\n",
        "explicit_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(\n",
        "    explicit_ensemble_model_config)\n",
        "# Let's plot our model.\n",
        "tf.keras.utils.plot_model(\n",
        "    explicit_ensemble_model, show_layer_names=False, rankdir='LR')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PJYR0i6MMDyh"
      },
      "source": [
        "As before, we compile, fit, and evaluate our model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "capt98IOMHEm"
      },
      "outputs": [],
      "source": [
        "explicit_ensemble_model.compile(\n",
        "    loss=tf.keras.losses.BinaryCrossentropy(),\n",
        "    metrics=[tf.keras.metrics.AUC()],\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))\n",
        "explicit_ensemble_model.fit(\n",
        "    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)\n",
        "print('Test Set Evaluation...')\n",
        "print(explicit_ensemble_model.evaluate(test_xs, test_ys))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VnI70C9gdKQw"
      },
      "source": [
        "### Random Lattice Ensemble\n",
        "\n",
        "If you are not sure which subsets of features to feed into your lattices, another option is to use random subsets of features for each lattice. This example creates a calibrated lattice ensemble model with 5 lattices and 3 features per lattice."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7EhWrQaPIXj8"
      },
      "outputs": [],
      "source": [
        "# This is a calibrated lattice ensemble model: inputs are calibrated, then\n",
        "# combined non-linearly and averaged using multiple lattice layers.\n",
        "random_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n",
        "    feature_configs=feature_configs,\n",
        "    lattices='random',\n",
        "    num_lattices=5,\n",
        "    lattice_rank=3,\n",
        "    output_min=min_label,\n",
        "    output_max=max_label - numerical_error_epsilon,\n",
        "    output_initialization=[min_label, max_label],\n",
        "    random_seed=42)\n",
        "# Now we must set the random lattice structure and construct the model.\n",
        "tfl.premade_lib.set_random_lattice_ensemble(random_ensemble_model_config)\n",
        "# A CalibratedLatticeEnsemble premade model constructed from the given\n",
        "# model config.\n",
        "random_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(\n",
        "    random_ensemble_model_config)\n",
        "# Let's plot our model.\n",
        "tf.keras.utils.plot_model(\n",
        "    random_ensemble_model, show_layer_names=False, rankdir='LR')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sbxcIF0PJUDc"
      },
      "source": [
        "As before, we compile, fit, and evaluate our model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "w0YdCDyGJY1G"
      },
      "outputs": [],
      "source": [
        "random_ensemble_model.compile(\n",
        "    loss=tf.keras.losses.BinaryCrossentropy(),\n",
        "    metrics=[tf.keras.metrics.AUC()],\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))\n",
        "random_ensemble_model.fit(\n",
        "    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)\n",
        "print('Test Set Evaluation...')\n",
        "print(random_ensemble_model.evaluate(test_xs, test_ys))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZhJWe7fZIs4-"
      },
      "source": [
        "### RTL Layer Random Lattice Ensemble\n",
        "\n",
        "When using a random lattice ensemble, you can specify that the model use a single `tfl.layers.RTL` layer. We note that `tfl.layers.RTL` only supports monotonicity constraints and must have the same lattice size for all features and no per-feature regularization. Note that using a `tfl.layers.RTL` layer lets you scale to much larger ensembles than using separate `tfl.layers.Lattice` instances.\n",
        "\n",
        "This example creates a calibrated lattice ensemble model with 5 lattices and 3 features per lattice."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0PC9oRFYJMF_"
      },
      "outputs": [],
      "source": [
        "# Make sure our feature configs have the same lattice size, no per-feature\n",
        "# regularization, and only monotonicity constraints.\n",
        "rtl_layer_feature_configs = copy.deepcopy(feature_configs)\n",
        "for feature_config in rtl_layer_feature_configs:\n",
        "  feature_config.lattice_size = 2\n",
        "  feature_config.unimodality = 'none'\n",
        "  feature_config.reflects_trust_in = None\n",
        "  feature_config.dominates = None\n",
        "  feature_config.regularizer_configs = None\n",
        "# This is a calibrated lattice ensemble model: inputs are calibrated, then\n",
        "# combined non-linearly and averaged using multiple lattice layers.\n",
        "rtl_layer_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n",
        "    feature_configs=rtl_layer_feature_configs,\n",
        "    lattices='rtl_layer',\n",
        "    num_lattices=5,\n",
        "    lattice_rank=3,\n",
        "    output_min=min_label,\n",
        "    output_max=max_label - numerical_error_epsilon,\n",
        "    output_initialization=[min_label, max_label],\n",
        "    random_seed=42)\n",
        "# A CalibratedLatticeEnsemble premade model constructed from the given\n",
        "# model config. Note that we do not have to specify the lattices by calling\n",
        "# a helper function (like before with random) because the RTL Layer will take\n",
        "# care of that for us.\n",
        "rtl_layer_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(\n",
        "    rtl_layer_ensemble_model_config)\n",
        "# Let's plot our model.\n",
        "tf.keras.utils.plot_model(\n",
        "    rtl_layer_ensemble_model, show_layer_names=False, rankdir='LR')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yWdxZpS0JWag"
      },
      "source": [
        "As before, we compile, fit, and evaluate our model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HQdkkWwqJW8p"
      },
      "outputs": [],
      "source": [
        "rtl_layer_ensemble_model.compile(\n",
        "    loss=tf.keras.losses.BinaryCrossentropy(),\n",
        "    metrics=[tf.keras.metrics.AUC()],\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))\n",
        "rtl_layer_ensemble_model.fit(\n",
        "    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)\n",
        "print('Test Set Evaluation...')\n",
        "print(rtl_layer_ensemble_model.evaluate(test_xs, test_ys))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A61VpAl8uOiT"
      },
      "source": [
        "### Crystals Lattice Ensemble\n",
        "\n",
        "Premade also provides a heuristic feature arrangement algorithm, called [Crystals](https://papers.nips.cc/paper/6377-fast-and-flexible-monotonic-functions-with-ensembles-of-lattices). To use the Crystals algorithm, first we train a prefitting model that estimates pairwise feature interactions. We then arrange the final ensemble such that features with more non-linear interactions are in the same lattices.\n",
        "\n",
        "the Premade Library offers helper functions for constructing the prefitting model configuration and extracting the crystals structure. Note that the prefitting model does not need to be fully trained, so a few epochs should be enough.\n",
        "\n",
        "This example creates a calibrated lattice ensemble model with 5 lattice and 3 features per lattice."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yT5eiknCu9sj"
      },
      "outputs": [],
      "source": [
        "# This is a calibrated lattice ensemble model: inputs are calibrated, then\n",
        "# combines non-linearly and averaged using multiple lattice layers.\n",
        "crystals_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n",
        "    feature_configs=feature_configs,\n",
        "    lattices='crystals',\n",
        "    num_lattices=5,\n",
        "    lattice_rank=3,\n",
        "    output_min=min_label,\n",
        "    output_max=max_label - numerical_error_epsilon,\n",
        "    output_initialization=[min_label, max_label],\n",
        "    random_seed=42)\n",
        "# Now that we have our model config, we can construct a prefitting model config.\n",
        "prefitting_model_config = tfl.premade_lib.construct_prefitting_model_config(\n",
        "    crystals_ensemble_model_config)\n",
        "# A CalibratedLatticeEnsemble premade model constructed from the given\n",
        "# prefitting model config.\n",
        "prefitting_model = tfl.premade.CalibratedLatticeEnsemble(\n",
        "    prefitting_model_config)\n",
        "# We can compile and train our prefitting model as we like.\n",
        "prefitting_model.compile(\n",
        "    loss=tf.keras.losses.BinaryCrossentropy(),\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))\n",
        "prefitting_model.fit(\n",
        "    train_xs,\n",
        "    train_ys,\n",
        "    epochs=PREFITTING_NUM_EPOCHS,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    verbose=False)\n",
        "# Now that we have our trained prefitting model, we can extract the crystals.\n",
        "tfl.premade_lib.set_crystals_lattice_ensemble(crystals_ensemble_model_config,\n",
        "                                              prefitting_model_config,\n",
        "                                              prefitting_model)\n",
        "# A CalibratedLatticeEnsemble premade model constructed from the given\n",
        "# model config.\n",
        "crystals_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(\n",
        "    crystals_ensemble_model_config)\n",
        "# Let's plot our model.\n",
        "tf.keras.utils.plot_model(\n",
        "    crystals_ensemble_model, show_layer_names=False, rankdir='LR')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PRLU1z-216h8"
      },
      "source": [
        "As before, we compile, fit, and evaluate our model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U73On3v91-Qq"
      },
      "outputs": [],
      "source": [
        "crystals_ensemble_model.compile(\n",
        "    loss=tf.keras.losses.BinaryCrossentropy(),\n",
        "    metrics=[tf.keras.metrics.AUC()],\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))\n",
        "crystals_ensemble_model.fit(\n",
        "    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)\n",
        "print('Test Set Evaluation...')\n",
        "print(crystals_ensemble_model.evaluate(test_xs, test_ys))"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "premade_models.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
